package com.alibaba.datax.core.transport.channel.memory;

import com.alibaba.datax.common.element.Record;
import com.alibaba.datax.common.exception.DataXException;
import com.alibaba.datax.common.util.Configuration;
import com.alibaba.datax.core.transport.channel.Channel;
import com.alibaba.datax.core.transport.record.TerminateRecord;
import com.alibaba.datax.core.util.FrameworkErrorCode;
import com.alibaba.datax.core.util.container.CoreConstant;

import java.util.Collection;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

/**
 * In-memory {@link Channel} implementation backed by a bounded {@link ArrayBlockingQueue}.
 *
 * <p>Single-record operations ({@link #doPush}/{@link #doPull}) rely on the queue's own blocking
 * {@code put}/{@code take}. Batch operations ({@link #doPushAll}/{@link #doPullAll}) are
 * serialized by a {@link ReentrantLock}. Besides the queue's slot limit, batch pushes also keep
 * the total {@link Record#getMemorySize() memory size} of queued records within
 * {@code byteCapacity}, which is inherited from {@link Channel}.
 */
public class MemoryChannel extends Channel {

  /** Maximum number of records handed out by a single {@link #doPullAll} call. */
  private final int bufferSize;

  /** Running total of {@link Record#getMemorySize()} for the records currently queued. */
  private final AtomicInteger memoryBytes = new AtomicInteger(0);

  /** Bounded queue holding the records; its capacity is the inherited {@code getCapacity()}. */
  private final ArrayBlockingQueue<Record> queue;

  /** Reentrant lock serializing the batch operations and owning the two conditions below. */
  private final ReentrantLock lock;

  /** Signalled after a batch pull frees space; batch pushers wait on it. */
  private final Condition notInsufficient;

  /** Signalled after a batch push adds records; batch pullers wait on it. */
  private final Condition notEmpty;

  /**
   * Creates a memory channel.
   *
   * @param configuration framework configuration; the pull batch size is read from
   *     {@link CoreConstant#DATAX_CORE_TRANSPORT_EXCHANGER_BUFFERSIZE}, and the queue capacity
   *     comes from the inherited {@code getCapacity()}
   */
  public MemoryChannel(final Configuration configuration) {
    super(configuration);
    this.queue = new ArrayBlockingQueue<Record>(this.getCapacity());
    this.bufferSize = configuration.getInt(CoreConstant.DATAX_CORE_TRANSPORT_EXCHANGER_BUFFERSIZE);

    this.lock = new ReentrantLock();
    this.notInsufficient = lock.newCondition();
    this.notEmpty = lock.newCondition();
  }

  /**
   * Closes the channel, then enqueues the {@link TerminateRecord} marker so that a reader taking
   * from the queue receives it.
   *
   * <p>{@code put} blocks while the queue is full. If the thread is interrupted, the marker is not
   * enqueued and the interrupt flag is restored.
   */
  @Override
  public void close() {
    super.close();
    try {
      this.queue.put(TerminateRecord.get());
    } catch (InterruptedException ex) {
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Discards every queued record and releases its bytes from {@link #memoryBytes}.
   *
   * <p>Records are removed with {@code poll()} rather than {@code queue.clear()}, so that exactly
   * the bytes of the removed records are subtracted. Without this, the counter would stay
   * inflated, and later {@link #doPushAll} calls could wait indefinitely for byte space.
   */
  @Override
  public void clear() {
    int released = 0;
    for (Record r = this.queue.poll(); r != null; r = this.queue.poll()) {
      released += r.getMemorySize();
    }
    memoryBytes.addAndGet(-released);
  }

  /**
   * Appends one record, blocking while the queue is full. The byte budget is not checked here.
   *
   * <p>If the thread is interrupted while waiting, the record is not enqueued and the interrupt
   * flag is restored.
   */
  @Override
  protected void doPush(Record r) {
    try {
      long startTime = System.nanoTime();
      // Blocking put into the ArrayBlockingQueue.
      this.queue.put(r);
      // Account the time spent blocked on the push.
      waitWriterTime += System.nanoTime() - startTime;
      memoryBytes.addAndGet(r.getMemorySize());
    } catch (InterruptedException ex) {
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Appends a batch of records while holding {@link #lock}.
   *
   * <p>Blocks until the queue has room for the whole batch and the batch fits in the remaining
   * byte budget.
   *
   * @param rs records to append
   * @throws DataXException ({@code RUNTIME_ERROR}) if interrupted while acquiring the lock or while
   *     waiting for space; the interrupt flag is restored first
   */
  @Override
  protected void doPushAll(Collection<Record> rs) {
    try {
      // Acquire outside the inner try. If this throws, the lock is not held and must not be
      // released: unlock() would throw IllegalMonitorStateException and hide the interruption.
      lock.lockInterruptibly();
      try {
        long startTime = System.nanoTime();
        int bytes = getRecordBytes(rs);
        // Wait while the batch would exceed the byte budget or the free queue slots.
        // The timed wait also re-checks space freed by operations that do not signal (doPull, clear).
        // NOTE(review): this loop never exits if rs.size() exceeds the queue's total capacity, or if
        // bytes alone exceed byteCapacity. Callers are presumably expected to keep batches within
        // both limits.
        while (memoryBytes.get() + bytes > this.byteCapacity
            || rs.size() > this.queue.remainingCapacity()) {
          notInsufficient.await(200L, TimeUnit.MILLISECONDS);
        }
        this.queue.addAll(rs);
        waitWriterTime += System.nanoTime() - startTime;
        memoryBytes.addAndGet(bytes);
        // Wake batch pullers waiting for data.
        notEmpty.signalAll();
      } finally {
        lock.unlock();
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw DataXException.asDataXException(FrameworkErrorCode.RUNTIME_ERROR, e);
    }
  }

  /**
   * Removes and returns the head record, blocking until one is available.
   *
   * @return the record taken from the queue
   * @throws IllegalStateException if interrupted while waiting; the interrupt flag is restored
   */
  @Override
  protected Record doPull() {
    try {
      long startTime = System.nanoTime();
      // Blocking take from the ArrayBlockingQueue.
      Record r = this.queue.take();
      // Account the time spent blocked on the pull.
      waitReaderTime += System.nanoTime() - startTime;
      memoryBytes.addAndGet(-r.getMemorySize());
      return r;
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException(e);
    }
  }

  /**
   * Replaces the contents of {@code rs} with up to {@code bufferSize} records from the queue.
   *
   * <p>Holds {@link #lock} while draining and blocks until at least one record is available.
   *
   * @param rs destination collection; it is cleared first and must not be {@code null}
   * @throws DataXException ({@code RUNTIME_ERROR}) if interrupted while acquiring the lock or while
   *     waiting for data; the interrupt flag is restored first
   */
  @Override
  protected void doPullAll(Collection<Record> rs) {
    assert rs != null;
    rs.clear();
    try {
      // The measured wait includes the time spent acquiring the lock.
      long startTime = System.nanoTime();
      // Acquire outside the inner try, so that unlock() only runs when the lock is actually held.
      lock.lockInterruptibly();
      try {
        // Drain at most bufferSize records; wait for data while the queue is empty.
        while (this.queue.drainTo(rs, bufferSize) <= 0) {
          notEmpty.await(200L, TimeUnit.MILLISECONDS);
        }
        waitReaderTime += System.nanoTime() - startTime;
        memoryBytes.addAndGet(-getRecordBytes(rs));
        // Wake batch pushers waiting for space.
        notInsufficient.signalAll();
      } finally {
        lock.unlock();
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw DataXException.asDataXException(FrameworkErrorCode.RUNTIME_ERROR, e);
    }
  }

  /** Sums {@link Record#getMemorySize()} over {@code rs}. */
  private int getRecordBytes(Collection<Record> rs) {
    int bytes = 0;
    for (Record r : rs) {
      bytes += r.getMemorySize();
    }
    return bytes;
  }

  /** Returns the number of records currently in the queue. */
  @Override
  public int size() {
    return this.queue.size();
  }

  /** Returns {@code true} if the queue currently holds no records. */
  @Override
  public boolean isEmpty() {
    return this.queue.isEmpty();
  }

}
